{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第7章 图像数据处理\n",
    "在第6章中详细介绍了CNN，并提到CNN给图像识别技术带来了突破性进展。这一章将从另外一个维度来进一步提升图像识别的精度以及训练的速度。\n",
    "\n",
    "喜欢摄影的读者都知道图像的亮度、对比度等属性对图像的影响是非常大的，相同物体在不同亮度、对比度下差别非常大。然而在很多图像识别问题中，这些因素都不应该影响最后的识别结果。所以**本章将介绍如何对图像数据进行预处理使训练得到的神经网络模型尽可能小地被无关因素所影响。**\n",
    "\n",
    "但与此同时，复杂的预处理过程可能导致训练效率的下降。为了减小预处理对于训练速度的影响，在本章中也将详细地介绍TensorFlow 中**多线程处理输入数据**的解决方案。\n",
    "\n",
    "## 7.1 TFRecord输入数据格式\n",
    "6.5节中给出了一个程序来处理花朵分类的数据。在这个程序中，使用了一个从类别名称到所有数据列表的词典来维护图像和类别的关系。这种方式的**可扩展性非常差**，当数据来源更加复杂、每一个样例中的信息更加丰富之后，这种方式就很难有效地记录输入数据中的信息了。于是TensorFlow提供了一种统一的格式来存储数据——**TFRecord**。\n",
    "\n",
    "### 7.1.1 TFRecord格式介绍\n",
    "TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下给出了tf.train.Example的定义："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "message Example {\n",
    "    Feature features = 1;\n",
    "};\n",
    "\n",
    "message Features {\n",
    "    map<string, Feature> feature = 1;\n",
    "};\n",
    "\n",
    "message Feature {\n",
    "    oneof kind {\n",
    "        BytesList bytes_list = 1;\n",
    "        FloatList float_list = 2;\n",
    "        Int64List int64_list = 3;\n",
    "  }\n",
    "};\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看出tf.train.Example的数据结构是比较简洁的。**tf.train.Example中包含了一个从属性名称到取值的字典**。其中属性名称为一个字符串，属性的取值可以为字符串（BytesList）、实数列表（FloatList）或者整数列表（Int64List）。比如将一张解码前的图像存为一个字符串，图像所对应的类别编号存为整数列表。\n",
    "\n",
    "** 7.1.2 TFRecord样例程序**\n",
    "以下程序给出了以MNIST数据集为例，如何将其转化为TFRecord格式，及读取这个TFRecord格式。\n",
    "\n",
    "**1. 写入TFRecord格式**："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-d2e8c23c76e8>:26: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../datasets/MNIST_data\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../datasets/MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ../../datasets/MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "TFRecord训练文件已保存。\n",
      "TFRecord测试文件已保存。\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import numpy as np\n",
    "\n",
    "# 生成整数型的属性\n",
    "def _int64_feature(value):\n",
    "    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))\n",
    "\n",
    "# 生成字符串型的属性\n",
    "def _bytes_feature(value):\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "\n",
    "# 将数据转化为tf.train.Example格式。\n",
    "def _make_example(pixels, label, image):\n",
    "    image_raw = image.tostring()\n",
    "    example = tf.train.Example(features=tf.train.Features(feature={\n",
    "        'pixels': _int64_feature(pixels),\n",
    "        'label': _int64_feature(np.argmax(label)),\n",
    "        'image_raw': _bytes_feature(image_raw)\n",
    "    }))\n",
    "    return example\n",
    "\n",
    "# 读取mnist训练数据。\n",
    "mnist = input_data.read_data_sets(\"../../datasets/MNIST_data\", \n",
    "                                  dtype=tf.uint8, \n",
    "                                  one_hot=True)\n",
    "images = mnist.train.images\n",
    "labels = mnist.train.labels\n",
    "pixels = images.shape[1]\n",
    "num_examples = mnist.train.num_examples\n",
    "\n",
    "# 输出包含训练数据的TFRecord文件。\n",
    "# 先创建一个writer来写TFRecord文件\n",
    "with tf.python_io.TFRecordWriter(path=\"output.tfrecords\") as writer:\n",
    "    for index in range(num_examples):\n",
    "        # 将一个样例转化为Example Protocol Buffer，并将所有的信息写入这个数据结构\n",
    "        example = _make_example(pixels, labels[index], images[index])\n",
    "        # 将一个Example写入TFRecord文件\n",
    "        writer.write(example.SerializeToString())\n",
    "print(\"TFRecord训练文件已保存。\")\n",
    "\n",
    "# 读取mnist测试数据。\n",
    "images_test = mnist.test.images\n",
    "labels_test = mnist.test.labels\n",
    "pixels_test = images_test.shape[1]\n",
    "num_examples_test = mnist.test.num_examples\n",
    "\n",
    "# 输出包含测试数据的TFRecord文件。\n",
    "with tf.python_io.TFRecordWriter(path=\"output_test.tfrecords\") as writer:\n",
    "    for index in range(num_examples_test):\n",
    "        example = _make_example(\n",
    "            pixels_test, labels_test[index], images_test[index])\n",
    "        writer.write(example.SerializeToString())\n",
    "print(\"TFRecord测试文件已保存。\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当数据量较大时，也可以将数据写入多个TFRecord文件。TensorFlow对文件列表读取数据提供了很好的支持，这在7.3.2节中会介绍。\n",
    "\n",
    "**2. 读取TFRecord文件**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\python\\training\\input.py:187: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "To construct input pipelines, use the `tf.data` module.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\python\\training\\input.py:187: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "To construct input pipelines, use the `tf.data` module.\n",
      "WARNING:tensorflow:From <ipython-input-1-2227bb390bc3>:32: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "To construct input pipelines, use the `tf.data` module.\n",
      "8 784 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  22  69\n",
      " 148 210 253 156 122   7   0   0  18   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0 100 221 252 252 253 252 252 252 113   0   0\n",
      " 185   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  31 221\n",
      " 252 252 244 232 231 251 252 252  98   0 211   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0 148 252 247 162  49   0   0  86 205 252\n",
      " 106   0 185   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      " 236 252 187   0   0   0   0   0  64 252 106   0 106   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0 254 253 144   0   0   0   0   0\n",
      "  27 229  62   0 107   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0 174 252 231   0   0   0   0  52 190 242 185   0 106   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0  30 212 251 135  22  84\n",
      " 206 242 252 252 250  58 115   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0  85 252 252 252 253 252 252 252 199 128  21 211   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0  62 199 252 252\n",
      " 252 253 252 190  42   7   0   0 211   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0  80 210 253 253 253 253 255 107   0   0   0   0   0\n",
      " 212   0   0   0   0   0   0   0   0   0   0   0   0   0  22 225 253 252\n",
      " 245 168 239 253  98   0   0   0   0   0 211   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0  29 237 252 241 117  19   0 149 253 175   2   0   0\n",
      "   0   0 106   0   0   0   0   0   0   0   0   0   0   0   0 153 252 226\n",
      "  80   0   0   0   0 183 252  91   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0 232 252  59   0   0   0   0   0 104 252 214\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 233\n",
      " 253  42   0   0   0   0   0  18 217 236  14   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0 197 252 156  36   0   0   0   0  15\n",
      " 211 252  84   0   0   0  71   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0 119 252 252 242 197 127 127 127 237 252 233   7   0   0   0 211   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0  16 231 252 253 252 252 252\n",
      " 252 253 252 205   0   0   0   0 158   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0  28 129 209 252 252 252 252 191 112  21   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "   0   0   0   0   0   0   0   0   0   0]\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# 创建一个reader来读取TFRecord文件中的样例\n",
    "reader = tf.TFRecordReader()\n",
    "# 创建一个列表来维护输入文件列表，在7.3.2节中会详细介绍\n",
    "filename_queue = tf.train.string_input_producer([\"output.tfrecords\"])\n",
    "\n",
    "# 从文件中读取一个样例。也可以使用read_up_to函数一次性读取多个样例\n",
    "_, serialized_example = reader.read(filename_queue)\n",
    "\n",
    "# 解析读取的样例。如果需要解析多个样例，可使用parse_example函数\n",
    "features = tf.parse_single_example(\n",
    "    serialized_example,\n",
    "    features={\n",
    "        # TensorFlow提供两种不同的属性解析方法。\n",
    "        # 一种是方法是tf.FixedLenFeature,这种方法解析的结果为一个Tensor。\n",
    "        # 另一种方法是tf.VarLenFeature，这种方法得到的解析结果为SparseTensor，用于处理稀疏数据。\n",
    "        # 这里解析数据的格式需要和上面程序写入数据的格式一致。\n",
    "        'image_raw':tf.FixedLenFeature([],tf.string),\n",
    "        'pixels':tf.FixedLenFeature([],tf.int64),\n",
    "        'label':tf.FixedLenFeature([],tf.int64)\n",
    "    })\n",
    "\n",
    "# tf.decode_raw函数可以将字符串解析成图像对应的像素数组\n",
    "images = tf.decode_raw(features['image_raw'],tf.uint8)\n",
    "labels = tf.cast(features['label'],tf.int32)\n",
    "pixels = tf.cast(features['pixels'],tf.int32)\n",
    "\n",
    "sess = tf.Session()\n",
    "# 启动多线程处理输入数据。7.3节将会更详细介绍多线程\n",
    "coord = tf.train.Coordinator()\n",
    "threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n",
    "\n",
    "# 每次运行可以读取TFRecord文件中的一个样例。当所有样例都读完之后，在此样例中程序会再重头读取\n",
    "for i in range(10):\n",
    "    image, label, pixel = sess.run([images, labels, pixels])\n",
    "print(label, pixel, image)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
